/* Patch geometry: patches are patchSide x patchSide; the `patches` kernel reads
 * patchSide2 consecutive floats of img_decomposed per pixel. */
#define patchSide 8
#define patchSide2 (patchSide*patchSide)

/* Work-group dimensions; the `patches` kernel requires exactly this size
 * (reqd_work_group_size) and sizes its __local scratch buffer from it. */
#define WGS_W 32
#define WGS_H 8

/* N: number of candidate offsets kept and written per reference patch.
 * NSHIFT: base-2 logarithm of N, usable as a shift in place of "* N". */
#define N 16
#define NSHIFT 4

/* The two spellings of the per-patch output stride must agree, otherwise output
 * rows overlap or leave gaps. */
#if (1 << NSHIFT) != N
#error "NSHIFT must equal log2(N)"
#endif

/*
 * Block matching: for every reference pixel (w_ind[ind_i], h_ind[ind_j]), compare
 * its patch against all patches in a searchWindow x searchWindow neighbourhood
 * using the sum of squared differences (SSD), and keep up to N matching pixel
 * offsets (y*width + x).
 *
 * Output: offsets[(ind_j*w_ind_size + ind_i)*N + 0 .. N-1]
 *   slot 0           - the reference pixel itself
 *   slots 1..        - candidates with SSD < threshold, ascending SSD
 *                      (ties keep scan order)
 *   unused slots     - filled with the reference pixel offset
 *
 * img_decomposed holds patchSide2 consecutive floats per pixel, indexed by
 * (y*width + x). NOTE(review): assumes it covers all width*height pixels; the
 * neighbourhood is clipped to [0,width) x [0,height) — TODO confirm against host.
 */
__kernel __attribute__((reqd_work_group_size(WGS_W, WGS_H, 1)))
void patches(
             __global const unsigned*   w_ind,
             __global const unsigned*   h_ind,
             int                        w_ind_size,
             int                        h_ind_size,
             __global const float*      img_decomposed,
             int                        width,
             int                        height,
             int                        searchWindow,
             float                      threshold,
             __global unsigned*         offsets
             )
{
    const int ind_i = get_global_id(0);
    const int ind_j = get_global_id(1);

    // The global range may be padded up to a multiple of the work-group size.
    if(ind_i >= w_ind_size || ind_j >= h_ind_size)
        return;

    const int i = w_ind[ind_i];   // reference column
    const int j = h_ind[ind_j];   // reference row

    const int local_idx = (int)(get_local_id(1)*WGS_W + get_local_id(0));

    typedef struct dist_t { unsigned offset; float dist; } dist_t;

    // Each work-item owns one row of N entries in this __local buffer and never
    // touches another item's row, so no barrier is required.
    __local dist_t buf[WGS_W*WGS_H][N];
    __local dist_t* _dists = buf[local_idx];

    // Slot 0 is the reference patch; pre-fill every slot with it so that slots
    // left unused by the search still hold a valid offset.
    const unsigned offsetOrg = (unsigned)(j*width + i);
#pragma unroll
    for(int k = 0; k < N; k++)
    {
        _dists[k].offset = offsetOrg;
        _dists[k].dist = 0;
    }
    int _outSize = 1;

    // size_t: (pixel index * patchSide2) can exceed INT_MAX on large images.
    const size_t refBase = (size_t)offsetOrg * patchSide2;
    float patch[patchSide2];
    for(int k = 0; k < patchSide2; k++)
        patch[k] = img_decomposed[refBase + k];

    const int halfSearchWindow = searchWindow>>1;
    for(int dj = -halfSearchWindow; dj <= halfSearchWindow; dj++)
    {
        const int y = j + dj;
        if(y < 0 || y >= height)
            continue;   // no stored patch outside the image

        for(int di = -halfSearchWindow; di <= halfSearchWindow; di++)
        {
            const int x = i + di;
            if(x < 0 || x >= width)
                continue;   // would otherwise wrap into the neighbouring row

            const unsigned offset = (unsigned)(y*width + x);
            const size_t candBase = (size_t)offset * patchSide2;

            // SSD between the candidate and the reference patch.
            float dist = 0;
            for(int k = 0; k < patchSide2; k++)
            {
                const float d = img_decomposed[candBase + k] - patch[k];
                dist += d*d;
            }

            if(dist >= threshold)
                continue;

            // NOTE(review): di == dj == 0 re-inserts the reference patch (dist 0),
            // duplicating slot 0 — confirm whether that is intended.

            // Insertion position: first slot after the reference whose distance
            // is strictly larger (equal distances keep scan order); else append.
            int insertPos = _outSize;
            for(int kk = 1; kk < _outSize; kk++)
            {
                if(dist < _dists[kk].dist)
                {
                    insertPos = kk;
                    break;
                }
            }
            if(insertPos >= N)
                continue;   // list full and candidate is no better than any kept entry

            if(_outSize < N)
                _outSize++;
            // Shift the tail right by one; if the list was full, the worst entry drops.
            for(int kk = _outSize-1; kk > insertPos; kk--)
                _dists[kk] = _dists[kk-1];
            _dists[insertPos].offset = offset;
            _dists[insertPos].dist = dist;
        }
    }

    // Each reference pixel owns N consecutive entries of the output table.
    __global unsigned* g_offsets = offsets + (size_t)(ind_j*w_ind_size + ind_i) * N;
#pragma unroll
    for(int k = 0; k < N; k++)
        g_offsets[k] = _dists[k].offset;
}
